(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載)
在 JAX 官方的教學網頁 JAX Quickstart [5.1] 上,開宗明義就說:
JAX 是一個可以跑在 CPU, GPU, 以及 TPU 上的 Numpy, …
(JAX is NumPy on the CPU, GPU, and TPU, … )
怎麼做到的呢?JAX 提供了與 Numpy 幾乎完全一致的 API !這些 API 是在 jax.numpy 之下,而習慣上我們是這樣使用的:
import jax.numpy as jnp
它對應了 Numpy 的習慣用法:
import numpy as np
在大部份的情況下,jax.numpy 的 API 用法,和標準 Numpy API 的用法相同,僅有少數的例外,這些例外,老頭會在日後加以說明。
現在我們先來看看一些簡單的例子:
# import jax.numpy and numpy
import jax.numpy as jnp
import numpy as np
# declare the data
#==========================================================================
# jax.numpy
x = jnp.arange(10)
# numpy
y = np.arange(10)
# show the declared data
#==========================================================================
print(f'Data defined by JAX : {x}')
print(f'Data defined by Numpy : {y}')
output:
Data defined by JAX : [0 1 2 3 4 5 6 7 8 9]
Data defined by Numpy : [0 1 2 3 4 5 6 7 8 9]
# operation: sum
#==========================================================================
# jax.numpy
sum_jnp = jnp.sum(x)
# numpy
sum_np = np.sum(y)
print(f'Sum by JAX : {sum_jnp}')
print(f'Sum by Numpy : {sum_np}')
output:
Sum by JAX : 45
Sum by Numpy : 45
# operation: dot
#==========================================================================
# jax.numpy
dot_jnp = jnp.dot(x,x)
# numpy
dot_np = np.dot(y,y)
print(f'Dot by JAX : {dot_jnp}')
print(f'Dot by Numpy : {dot_np}')
output:
Dot by JAX : 285
Dot by Numpy : 285
在以上的例子,jax.numpy 的 API 和 Numpy 的 API 幾乎是一對一的對應,語法 (syntax) 和語意 (semantics) 也完全相同。
在繼續談 jax.numpy 之前,老頭想先簡單介紹 %timeit 這個魔術指令 (magic command) [5.2]。%timeit 可以計算一列 Python 敍述 (statement) 執行時所需要的時間。例如:
%timeit 100.0 / 5.0
output: 11.2 ns ± 0.0468 ns per loop (mean ± std. dev. of 7 runs, 100000000 loops each)
%timeit 採用了「二層重覆執行」的方式,來得到精確且有統計意義的結果。
第一層:執行 Python 敍述 N 次 (N 個 loops),計算平均時間,得到 Python 敍述執行一次的時間值。
第二層:重複第一層 R 次 (R 個 runs),得到 R 個結果,利用這 R 個結果,可以得到平均數、標準差等統計量。
用選項 -n 來指定第一層的次數,用 -r 來指定第二層的次數,沒有指定的話,則選用既定值 (default values)。
-n : default value 由系統自行判斷要達要適當精確度所需的重覆次數。
-r : default value 7
讀者可能會納悶,為什麼需要第一層的重覆呢?因為由作業系統 (operation system; OS) 所提供的最小計時單位,可能會比一些簡單的敍述所需的時間大得多,如上面的例子,100.0 / 5.0 所需的時間僅需 11.2 ns,這不是作業系統計時器所可以量出來的,所以 %timeit 必須重覆執行它 100000000 loops 才能正確的估算它的執行時間。
做個實驗
%timeit -n 1 -r 1 100.0/5.0
output:
668 ns ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
這個實驗是在 google colab 的 linux 虛擬機 (VM) 上執行的,版本是 5.4.188+
!uname -a
output:
Linux 47cad5f60b70 5.4.188+ #1 SMP Sun Apr 24 10:03:06 PDT 2022 x86_64 x86_64 x86_64 GNU/Linux
linux 能夠支援的計時精度,若導入了高解析度計時功能 (HRT; high-resolution timer)大約只到微秒 (micro-second) 等級,因此只執行 100.0/5.0 所得到結果比實際所需執行時間大得多!
希望這個簡單的介紹能讓大家了解它的基本用法,接下來 %timeit 會被用來計算 jax.numpy 及 Numpy 執行同一個運算所需要的時間,讓大家看看 jax.numpy 到底可以快多少!
註:
[5.1] 參考 JAX Quickstart。
[5.2] 老頭在此不對「魔力指令」多加說明,讀者們可以參考 IPython 的官方文件 Build-in magic commands 。